source('helper_functions.R')

## Standard errors that the Angrist-Pischke simulation computes by default.
## Each list name is a variance estimator (used as `Vhat`), and each value
## lists the degrees-of-freedom / critical-value rules (used as `dof`) to
## combine with it. ConvertSETypes is not defined here (presumably in
## helper_functions.R); it is assumed to turn this list into the `se.types`
## table whose rows BFSize iterates over -- confirm there.
APDefault.se <- local({
    wild.rules <- c("N", "t", "wild", "wild0", "wild0s")
    ConvertSETypes(list(homo = c("N", "t"),
                        HC0  = wild.rules,
                        HC2  = c(wild.rules, "Welch", "Opt", "BM"),
                        HC3  = c("N", "t"),
                        max0 = "N",
                        max2 = "N"))
})

## Sum of squared deviations from the mean (sum of squared residuals).
## For a vector Y the result is a single number. For a matrix Y (nsim x nd,
## one simulated sample per row) the result is a length-nsim vector holding
## the SSR of each ROW.
SSR <- function(Y) UseMethod("SSR")     # S3 generic

SSR.default <- function(Y) {
    dev <- Y - mean(Y)
    sum(dev^2)
}

SSR.matrix <- function(Y) {
    ## rowMeans(Y) has length nrow(Y) and is recycled down every column, so
    ## each element is centred at the mean of its own row.
    dev <- Y - rowMeans(Y)
    rowSums(dev^2)
}

## Denominator by which BFHatSE divides a group's sum of squared residuals
## for the robust variance estimators in the Behrens-Fisher problem:
##   HC0 -> n^2,  HC2 -> n (n - 1),  HC3 -> (n - 1)^2
## Any other Vhat yields NULL.
BFScale <- function(nd, Vhat) {
    switch(Vhat,
           HC0 = nd^2,
           HC2 = nd * (nd - 1),
           HC3 = (nd - 1)^2)
}

## Standard error of the difference in means, computed from the groups'
## sums of squared residuals s0, s1 and sizes n0, n1.
##   homo         pooled-variance (homoskedastic) standard error
##   HC0/HC2/HC3  robust standard error, each SSR divided by BFScale(n, Vhat)
##   max0/max2    elementwise maximum of the homoskedastic and the HC0/HC2 one
## s0 and s1 may be vectors (one entry per simulation); pmax keeps the max
## variants elementwise. An unrecognised Vhat gives NULL.
BFHatSE <- function(s0, s1, n0, n1, Vhat) {
    homo.se <- function()
        sqrt((1/n0+1/n1)*(s0+s1)/(n0+n1-2))
    robust.se <- function(v)
        sqrt(s0/BFScale(n0, v)+s1/BFScale(n1, v))

    switch(Vhat,
           homo = homo.se(),
           HC0 = ,
           HC2 = ,
           HC3 = robust.se(Vhat),
           max0 = pmax(homo.se(), robust.se("HC0")),
           max2 = pmax(homo.se(), robust.se("HC2")))
}

## Degrees-of-freedom adjusted t-test of a difference in means: the normal
## critical value is replaced by cv = qt(0.975, df).
##
##   t.numer  difference(s) in means minus the hypothesised value
##   s0, s1   sums of squared residuals of the two samples
##   n0, n1   sizes of the two samples
##   Vhat     variance estimator, passed on to BFHatSE
##   dof      "N"     df = Inf (normal critical value)
##            "t"     df = n0 + n1 - 2
##            "Welch" Welch-Satterthwaite df with sample variances
##            "Opt"   same formula with variances ratio^2 and 1
##            "BM"    same formula with equal variances
##   ratio    ratio of standard deviations (group 0 over group 1); only
##            used when dof == "Opt"
##
## Returns list(adj.t, adj.se, df):
##   adj.t  = |t| * 1.96 / cv, so adj.t < 1.96 exactly when |t| < cv
##   adj.se = se * cv / qnorm(0.975)
##   df     = degrees of freedom behind cv
BFAdjSE <- function(t.numer, s0, s1, n0, n1, Vhat, dof, ratio=NA) {
    ## Welch-Satterthwaite degrees of freedom for group variances v0, v1
    satterthwaite <- function(v0, v1) {
        w <- (v0/n0) / (v0/n0 + v1/n1)
        1/(w^2/(n0-1) + ((1-w)^2)/(n1-1))
    }

    nu <- switch(dof,
                 N = Inf,
                 t = n0+n1-2,
                 Opt = satterthwaite(ratio^2, 1),
                 Welch = satterthwaite(s0/(n0-1), s1/(n1-1)),
                 BM = satterthwaite(1, 1))

    cv <- qt(0.975, df=nu)
    se <- BFHatSE(s0, s1, n0, n1, Vhat)

    list(adj.t=abs(t.numer)/se*1.96/cv,
         adj.se=se*cv/qnorm(0.975),
         df=nu)
}

## Draw an nsim x nd matrix of errors (one simulated sample per row).
## Each distribution is standardised to mean 0 and variance 1 before being
## multiplied by sigmad:
##   "Normal"    standard normal
##   "logNormal" exp(Z) minus its mean exp(1/2), divided by its standard
##               deviation sqrt(exp(2) - exp(1))
##   "t"         t with 3 degrees of freedom divided by sqrt(3)
BFDrawErrors <- function(nsim, nd, sigmad, distr) {
    ndraws <- nd*nsim
    ep <- switch(distr,
                 logNormal = {
                     z <- rnorm(ndraws)
                     (exp(z)-exp(1/2)) / sqrt(exp(2)-exp(1))
                 },
                 Normal = rnorm(ndraws),
                 t = rt(ndraws, 3)/sqrt(3))

    ## Filled column-major with nrow = nsim, so the result is nsim x nd
    sigmad*matrix(ep, nrow=nsim)
}

BFSize <- function(nsim, nsimb, se.types,
                   dc=list(n0=27, n1=3, ratio=1/2, distr0="Normal", distr1="Normal")) {
    ## Compute test size and standard errors in the Behrens-Fisher problem.
    ##
    ## nsim:     number of simulated samples (rows of Y0 and Y1)
    ## nsimb:    number of those samples (the first nsimb) on which rows with
    ##           a wild-bootstrap dof are evaluated; must not exceed nsim
    ## se.types: table whose rows are (Vhat, dof) pairs; rows whose dof
    ##           contains "wild" go through BFWildBS, all others through
    ##           BFAdjSE
    ## dc:       design constants: sample sizes n0, n1; ratio of standard
    ##           deviations sd(Y0)/sd(Y1); error distributions distr0, distr1
    ##           (see BFDrawErrors)
    ##
    ## Calls set.seed(1) (resetting the global RNG state), so the results are
    ## reproducible and all rows of se.types share the same simulated data.
    ##
    ## Returns se.types with these columns appended:
    ##   Cov          percentage of simulations with adj.t < 1.96 (coverage)
    ##   MedL         median of adj.se
    ##   5%L,90%,95%L 5%, 90% and 95% quantiles of adj.se
    ##   mdof, vardof mean and standard deviation (not variance) of the
    ##                degrees of freedom; NA for bootstrap rows

    beta0 <- 0

    set.seed(1)
    Y0 <- BFDrawErrors(nsim, dc$n0, dc$ratio, dc$distr0)
    Y1 <- beta0 + BFDrawErrors(nsim, dc$n1, 1, dc$distr1)

    ## Per-simulation SSRs and numerator of the t-statistic
    s0 <- SSR(Y0)
    s1 <- SSR(Y1)
    t.numer <- rowMeans(Y1) - rowMeans(Y0) - beta0

    ## Wild bootstrap on the first nsimb samples -> data frame (adj.t, adj.se).
    ## seq_len() (not seq()) so that nsimb = 0 gives zero draws, not c(1, 0).
    WildResults <- function(Vhat, dof) {
        if (nsimb > nsim)
            stop("nsimb (", nsimb, ") must not exceed nsim (", nsim, ")",
                 call.=FALSE)
        boot <- lapply(seq_len(nsimb), function(j)
            BFWildBS(Y0[j, ], Y1[j, ], Vhat, dof, beta0))
        data.frame(
            adj.t=vapply(boot, function(b) as.numeric(b[["adj.t"]]), numeric(1)),
            adj.se=vapply(boot, function(b) as.numeric(b[["adj.se"]]), numeric(1)))
    }

    ## Summaries for one row of se.types: se[1] is Vhat, se[2] is dof
    SimResults <- function(se) {
        s <- if (grepl("wild", se[2])) {
                 WildResults(se[1], se[2])
             } else {
                 BFAdjSE(t.numer, s0, s1, dc$n0, dc$n1, se[1], se[2], dc$ratio)
             }
        adj.se <- s[["adj.se"]]
        dof <- s[["df"]]                # NULL for bootstrap rows
        ret <- c(100*mean(s[["adj.t"]] < 1.96), median(adj.se),
                 quantile(adj.se, 0.05), quantile(adj.se, 0.90),
                 quantile(adj.se, 0.95),
                 if (!is.null(dof)) c(mean(dof), sd(dof)) else c(NA, NA))
        names(ret) <- c("Cov", "MedL", "5%L", "90%", "95%L", "mdof", "vardof")
        ret
    }

    cbind(se.types, t(apply(se.types, 1, SimResults)))
}

BFWildBS <- function(Y0, Y1, Vhat, dof, beta0=0, nboot=999, adj.se=TRUE){
    ## Compute test of H0:E[Y(1)-Y(0)]=beta0 in the Behrens-Fisher problem based on
    ## Wild Bootstrap with Rademacher weights.
    ##
    ## Y0, Y1: the two samples (vectors)
    ## Vhat:   variance estimator for the t-statistic, passed to BFHatSE
    ## dof:    "wild"  unrestricted bootstrap, symmetric critical value
    ##         other   bootstrap imposing the null b = b0; symmetric test if
    ##                 dof contains "s" (e.g. "wild0s"), equal-tailed
    ##                 otherwise (e.g. "wild0")
    ## beta0:  hypothesised difference in means
    ## nboot:  number of bootstrap draws
    ## adj.se: null-imposed bootstrap only. If FALSE or NULL, skip inverting
    ##         the test into an interval and return adj.se = 0.
    ##
    ## Returns list(adj.t, adj.se); adj.t > 1.96 exactly when the bootstrap
    ## test rejects. For "wild", adj.se = cv*se/qnorm(0.975). Otherwise it is
    ## (x[2]-x[1])/(2*qnorm(0.975)), where x is what FindMinInterval
    ## (helper_functions.R) returns for the rejection indicator over
    ## beta0 +/- 5 -- presumably the interval endpoints; confirm there.

    n0 <- length(Y0)
    n1 <- length(Y1)
    d <- mean(Y1)-mean(Y0)              # Point estimate

    se <- BFHatSE(SSR(Y0), SSR(Y1), n0, n1, Vhat)

    ## Draw Rademacher random variables, n0-by-nboot and n1-by-nboot
    v0m <- matrix(sample(c(-1,1), nboot*n0, replace=TRUE), nrow=n0)
    v1m <- matrix(sample(c(-1,1), nboot*n1, replace=TRUE), nrow=n1)

    ## t-statistics for bootstrap samples stored one per row (nboot x nd)
    TStat <- function(Y0, Y1, beta0)
        (rowMeans(Y1)-rowMeans(Y0)-beta0) / BFHatSE(SSR(Y0), SSR(Y1), n0, n1, Vhat)

    if (dof=="wild") {
        ## Bootstrap draws {nboot x nd}. The residual vector is recycled down
        ## each column of the weight matrix, and t() puts one draw per row.
        Y0m <- mean(Y0) + t((Y0-mean(Y0))*v0m)
        Y1m <- mean(Y1) + t((Y1-mean(Y1))*v1m)

        ## Bootstrap critical value, symmetric; bootstrap t centred at d
        tm <- TStat(Y0m, Y1m, d)
        cv <- quantile(abs(tm),0.95)

        return(list(adj.t=abs((d-beta0)/se)*1.96/cv,
                    adj.se= cv*se / qnorm(0.975)))
    }

    ## Null-imposed bootstrap: restricted fit of Y on (1, D) with slope b0.
    ## These do not depend on b0, so compute them once.
    mY <- mean(c(Y1,Y0))                # overall mean of Y
    mD <- n1/(n1+n0)                    # overall mean of D
    symmetric <- grepl("s", dof)

    Testb0 <- function(b0) {            # Wild Bootstrap test of beta=b0
        Y0m <- mY-mD*b0 + t((Y0-mY+mD*b0)*v0m)
        Y1m <- mY+(1-mD)*b0 + t((Y1-mY+(mD-1)*b0)*v1m)
        tm <- TStat(Y0m, Y1m, b0)

        t.data <- (d-b0)/se
        if (symmetric) {
            abs(t.data)*1.96/quantile(abs(tm), 0.95)
        } else {
            ## Express
            ## ((t.data > quantile(tm, 0.975)) | (t.data < quantile(tm, 0.025)))
            ## as an adjusted t-statistic
            max(t.data-quantile(tm, 0.975)+1.96,
                quantile(tm, 0.025)-t.data+1.96)
        }
    }

    ## adj.se is a logical flag: FALSE (previously ignored) or NULL skips the
    ## costly interval search.
    want.interval <- !is.null(adj.se) && !identical(adj.se, FALSE)
    x <- if (want.interval) {
             FindMinInterval(function(b0) Testb0(b0)>1.96, beta0+c(-5,5), tol=0.01)
         } else {
             c(0,0)
         }

    list(adj.t=Testb0(beta0), adj.se=((x[2]-x[1]) /(2*qnorm(0.975))))
}

APSimulation <- function(nsim, nsimb, se.types=APDefault.se,
                         dc=list(n0=27, n1=3, distr0="Normal", distr1="Normal"),
                         printDetail=FALSE, ratio=c(0.5, 0.85, 1, 1.18, 2)) {
    ## Simulations based on Angrist-Pischke design. Return data frame with results
    ##
    ## nsim, nsimb, se.types: passed to BFSize for every ratio.
    ## dc:          design constants (n0, n1, distr0, distr1). For each run its
    ##              `ratio` entry is SET to the current ratio, overriding any
    ##              ratio already in dc (previously a ratio in dc shadowed the
    ##              simulated one, silently repeating the same design).
    ## printDetail: if TRUE, print the full formatted BFSize table per ratio.
    ## ratio:       ratios of standard deviations sd(Y0)/sd(Y1) to simulate.
    ##
    ## Returns list(cis, dof):
    ##   cis: se.types followed by the "Cov" and "MedL" columns for each ratio
    ##   dof: se.types followed by the "mdof" column for each ratio

    l <- lapply(ratio, function(r) {
        dcr <- dc
        dcr$ratio <- r
        format(BFSize(nsim, nsimb, se.types, dc=dcr), digits=2, nsmall=1)
    })
    if (isTRUE(printDetail)) print(l)

    s <- do.call(cbind, l)
    ## drop=FALSE keeps column names when a single ratio is simulated
    list(cis=cbind(se.types, s[, names(s) %in% c("Cov", "MedL"), drop=FALSE]),
         dof=cbind(se.types, s[, names(s) == "mdof", drop=FALSE]))
}
